{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "code_folding": [],
    "customInput": null,
    "hidden_ranges": [],
    "originalKey": "8ca7695d-8a19-454e-b32b-3d5c36d52faf",
    "showInput": false
   },
   "source": [
    "The purpose of this example is to demostrate the overall flow of lowering a PyTorch model\n",
    "to TensorRT conveniently with lower.py. We integrated the transformation process including `TRTInterpreter`, `TRTModule`, pass optimization into the `lower_to_trt` API, users are encouraged to check the docstring of the API and tune it to meet your needs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1661189891682,
    "executionStopTime": 1661189891856,
    "originalKey": "7db2accc-9fa4-4a1e-8142-d887f2947bcd",
    "requestMsgId": "b5d8efce-0963-4074-bc9d-e8e1a78fd424",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "import typing as t\n",
    "from copy import deepcopy\n",
    "from dataclasses import dataclass, field, replace\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "from torch_tensorrt.fx.lower import compile\n",
    "from torch_tensorrt.fx.utils import LowerPrecision"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "code_folding": [],
    "customInput": null,
    "hidden_ranges": [],
    "originalKey": "e324a1ff-1bc2-4e78-932f-33534c3ac3f5",
    "showInput": false
   },
   "source": [
    "Specify the `configuration` class used for FX path lowering and benchmark. To extend, add a new configuration field to this class, and modify the lowering or benchmark behavior in `run_configuration_benchmark()` correspondingly. It automatically stores all its values to a `Result` dataclass.   \n",
    "`Result` is another dataclass that holds raw essential benchmark result values like Batch size, QPS, accuracy, etc..\n",
    ""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1661189260550,
    "executionStopTime": 1661189262039,
    "hidden_ranges": [],
    "originalKey": "2835fffa-cc50-479a-9080-c4f7002c0726",
    "requestMsgId": "6ea72dbf-dbfe-451e-8613-15f87e34a1a5",
    "showInput": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0822 102740.872 _utils_internal.py:179] NCCL_DEBUG env var is set to None\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0822 102740.873 _utils_internal.py:188] NCCL_DEBUG is INFO from /etc/nccl.conf\n"
     ]
    }
   ],
   "source": [
    "@dataclass\n",
    "class Configuration:\n",
    "    # number of inferences to run\n",
    "    batch_iter: int\n",
    "\n",
    "    # Input batch size\n",
    "    batch_size: int\n",
    "\n",
    "    # Friendly name of the configuration\n",
    "    name: str = \"\"\n",
    "\n",
    "    # Whether to apply TRT lowering to the model before benchmarking\n",
    "    trt: bool = False\n",
    "\n",
    "    # Whether to apply engine holder to the lowered model\n",
    "    jit: bool = False\n",
    "\n",
    "    # Whether to enable FP16 mode for TRT lowering\n",
    "    fp16: bool = False\n",
    "\n",
    "    # Relative tolerance for accuracy check after lowering. -1 means do not\n",
    "    # check accuracy.\n",
    "    accuracy_rtol: float = -1  # disable\n",
    "\n",
    "@dataclass\n",
    "class Result:\n",
    "    module: torch.nn.Module = field(repr=False)\n",
    "    input: t.Any = field(repr=False)\n",
    "    conf: Configuration\n",
    "    time_sec: float\n",
    "    accuracy_res: t.Optional[bool] = None\n",
    "\n",
    "    @property\n",
    "    def time_per_iter_ms(self) -> float:\n",
    "        return self.time_sec * 1.0e3\n",
    "\n",
    "    @property\n",
    "    def qps(self) -> float:\n",
    "        return self.conf.batch_size / self.time_sec\n",
    "\n",
    "    def format(self) -> str:\n",
    "        return (\n",
    "            f\"== Benchmark Result for: {self.conf}\\n\"\n",
    "            f\"BS: {self.conf.batch_size}, \"\n",
    "            f\"Time per iter: {self.time_per_iter_ms:.2f}ms, \"\n",
    "            f\"QPS: {self.qps:.2f}, \"\n",
    "            f\"Accuracy: {self.accuracy_res} (rtol={self.conf.accuracy_rtol})\"\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "code_folding": [],
    "customInput": null,
    "hidden_ranges": [],
    "originalKey": "3e462cf6-d282-402d-955b-a3ecb400bf0b",
    "showInput": false
   },
   "source": [
    "Run FX path lowering and benchmark the given model according to the specified benchmark configuration. Prints the benchmark result for each configuration at the end of the run. `benchmark_torch_function` is the actual function that computes the fixed number of iterations of functions runs.\n",
    "The FX path lowering and TensorRT engine creation is integrated into `compile()` API which is defined in `fx/lower.py` file.\n",
    "It is good to list it out and show the usage of it. It takes in original module, input and lowering setting, run lowering workflow to turn module into a executable TRT engine \n",
    "```\n",
    "def compile(\n",
    "    module: nn.Module,\n",
    "    input: ,\n",
    "    max_batch_size: int = 2048,\n",
    "    max_workspace_size=1 << 25,\n",
    "    explicit_batch_dimension=False,\n",
    "    lower_precision=LowerPrecision.FP16,\n",
    "    verbose_log=False,\n",
    "    timing_cache_prefix=\"\",\n",
    "    save_timing_cache=False,\n",
    "    cuda_graph_batch_size=-1,\n",
    "    dynamic_batch=False,\n",
    ") -> nn.Module:\n",
    "``` \n",
    "\n",
    "    Args:\n",
    "        module: Original module for lowering.\n",
    "        input: Input for module.\n",
    "        max_batch_size: Maximum batch size (must be >= 1 to be set, 0 means not set)\n",
    "        max_workspace_size: Maximum size of workspace given to TensorRT.\n",
    "        explicit_batch_dimension: Use explicit batch dimension in TensorRT if set True, otherwise use implicit batch dimension.\n",
    "        lower_precision: lower_precision config given to TRTModule.\n",
    "        verbose_log: Enable verbose log for TensorRT if set True.\n",
    "        timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.\n",
    "        save_timing_cache: Update timing cache with current timing cache data if set to True.\n",
    "        cuda_graph_batch_size: Cuda graph batch size, default to be -1.\n",
    "        dynamic_batch: batch dimension (dim=0) is dynamic.\n",
    "\n",
    "    Returns:\n",
    "        A torch.nn.Module lowered by TensorRT.\n",
    "We testd a resnet18 network with input size of [128,3,224,224] for [Batch, Channel, Width, Height]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1661189697773,
    "executionStopTime": 1661189753875,
    "hidden_ranges": [],
    "originalKey": "3002935b-b95a-4a08-a57f-f7a35485af5b",
    "requestMsgId": "dc73f2d0-427b-4f71-bec1-b118cc5642d0",
    "showInput": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0822 103458.189 manifold.py:1435] URL manifold://torchvision/tree/models/resnet18-f37072fd.pth was already cached in /home/wwei6/.torch/iopath_cache/manifold_cache/tree/models/resnet18-f37072fd.pth\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Running benchmark for: Configuration(batch_iter=50, batch_size=128, name='CUDA Eager', trt=False, jit=False, fp16=False, accuracy_rtol=-1) green\n== Start benchmark iterations\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "== End benchmark iterations\n=== Running benchmark for: Configuration(batch_iter=50, batch_size=128, name='TRT FP32 Eager', trt=True, jit=False, fp16=False, accuracy_rtol=0.001) green\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0822 103501.297 pass_utils.py:166] == Log pass <function fuse_permute_matmul at 0x7f787a0e08b0> before/after graph to /tmp/tmpe_7p37fq\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0822 103501.390 pass_utils.py:166] == Log pass <function fuse_permute_linear at 0x7f787a0e0670> before/after graph to /tmp/tmpg_a347f0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0822 103501.509 lower_pass_manager_builder.py:151] Now lowering submodule _run_on_acc_0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0822 103501.511 lower.py:89] split_name='_run_on_acc_0' self.lower_setting.input_specs=[InputTensorSpec(shape=torch.Size([128, 3, 224, 224]), dtype=torch.float32, device=device(type='cuda', index=0), shape_ranges=[], has_batch_dim=True)]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\nSupported node types in the model:\nacc_ops.conv2d: ((), {'input': torch.float32, 'weight': torch.float32})\nacc_ops.batch_norm: ((), {'input': torch.float32, 'running_mean': torch.float32, 'running_var': torch.float32, 'weight': torch.float32, 'bias': torch.float32})\nacc_ops.relu: ((), {'input': torch.float32})\nacc_ops.max_pool2d: ((), {'input': torch.float32})\nacc_ops.add: ((), {'input': torch.float32, 'other': torch.float32})\nacc_ops.adaptive_avg_pool2d: ((), {'input': torch.float32})\nacc_ops.flatten: ((), {'input': torch.float32})\nacc_ops.linear: ((), {'input': torch.float32, 'weight': torch.float32, 'bias': torch.float32})\n\nUnsupported node types in the model:\n\nGot 1 acc subgraphs and 0 non-acc subgraphs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0822 103503.964 fx2trt.py:204] Run Module elapsed time: 0:00:00.435984\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0822 103520.647 fx2trt.py:258] Build TRT engine elapsed time: 0:00:16.681226\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0822 103520.658 lower_pass_manager_builder.py:168] Lowering submodule _run_on_acc_0 elapsed time 0:00:19.147071\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "== Start benchmark iterations\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "== End benchmark iterations\n=== Running benchmark for: Configuration(batch_iter=50, batch_size=128, name='TRT FP16 Eager', trt=True, jit=False, fp16=True, accuracy_rtol=0.01) green\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0822 103523.067 pass_utils.py:166] == Log pass <function fuse_permute_matmul at 0x7f787a0e08b0> before/after graph to /tmp/tmpgphlicna\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0822 103523.106 pass_utils.py:166] == Log pass <function fuse_permute_linear at 0x7f787a0e0670> before/after graph to /tmp/tmpy9cumddi\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0822 103523.173 lower_pass_manager_builder.py:151] Now lowering submodule _run_on_acc_0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0822 103523.174 lower.py:89] split_name='_run_on_acc_0' self.lower_setting.input_specs=[InputTensorSpec(shape=torch.Size([128, 3, 224, 224]), dtype=torch.float16, device=device(type='cuda', index=0), shape_ranges=[], has_batch_dim=True)]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\nSupported node types in the model:\nacc_ops.conv2d: ((), {'input': torch.float16, 'weight': torch.float16})\nacc_ops.batch_norm: ((), {'input': torch.float16, 'running_mean': torch.float16, 'running_var': torch.float16, 'weight': torch.float16, 'bias': torch.float16})\nacc_ops.relu: ((), {'input': torch.float16})\nacc_ops.max_pool2d: ((), {'input': torch.float16})\nacc_ops.add: ((), {'input': torch.float16, 'other': torch.float16})\nacc_ops.adaptive_avg_pool2d: ((), {'input': torch.float16})\nacc_ops.flatten: ((), {'input': torch.float16})\nacc_ops.linear: ((), {'input': torch.float16, 'weight': torch.float16, 'bias': torch.float16})\n\nUnsupported node types in the model:\n\nGot 1 acc subgraphs and 0 non-acc subgraphs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0822 103523.466 fx2trt.py:204] Run Module elapsed time: 0:00:00.288043\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0822 103553.687 fx2trt.py:258] Build TRT engine elapsed time: 0:00:30.220316\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0822 103553.698 lower_pass_manager_builder.py:168] Lowering submodule _run_on_acc_0 elapsed time 0:00:30.523791\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "== Start benchmark iterations\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "== End benchmark iterations\n== Benchmark Result for: Configuration(batch_iter=50, batch_size=128, name='CUDA Eager', trt=False, jit=False, fp16=False, accuracy_rtol=-1)\nBS: 128, Time per iter: 14.66ms, QPS: 8732.53, Accuracy: None (rtol=-1)\n== Benchmark Result for: Configuration(batch_iter=50, batch_size=128, name='TRT FP32 Eager', trt=True, jit=False, fp16=False, accuracy_rtol=0.001)\nBS: 128, Time per iter: 7.27ms, QPS: 17595.70, Accuracy: None (rtol=0.001)\n== Benchmark Result for: Configuration(batch_iter=50, batch_size=128, name='TRT FP16 Eager', trt=True, jit=False, fp16=True, accuracy_rtol=0.01)\nBS: 128, Time per iter: 4.49ms, QPS: 28480.34, Accuracy: None (rtol=0.01)\n"
     ]
    }
   ],
   "source": [
    "def benchmark_torch_function(iters: int, f, *args) -> float:\n",
    "    \"\"\"Estimates the average time duration for a single inference call in second\n",
    "\n",
    "    If the input is batched, then the estimation is for the batches inference call.\n",
    "    \"\"\"\n",
    "    with torch.inference_mode():\n",
    "        f(*args)\n",
    "    torch.cuda.synchronize()\n",
    "    start_event = torch.cuda.Event(enable_timing=True)\n",
    "    end_event = torch.cuda.Event(enable_timing=True)\n",
    "    print(\"== Start benchmark iterations\")\n",
    "    with torch.inference_mode():\n",
    "        start_event.record()\n",
    "        for _ in range(iters):\n",
    "            f(*args)\n",
    "        end_event.record()\n",
    "    torch.cuda.synchronize()\n",
    "    print(\"== End benchmark iterations\")\n",
    "    return (start_event.elapsed_time(end_event) * 1.0e-3) / iters\n",
    "\n",
    "\n",
    "def run_configuration_benchmark(\n",
    "    module,\n",
    "    input,\n",
    "    conf: Configuration,\n",
    ") -> Result:\n",
    "    print(f\"=== Running benchmark for: {conf}\", \"green\")\n",
    "    time = -1.0\n",
    "\n",
    "    if conf.fp16:\n",
    "        module = module.half()\n",
    "        input = [i.half() for i in input]\n",
    "\n",
    "    if not conf.trt:\n",
    "        # Run eager mode benchmark\n",
    "        time = benchmark_torch_function(conf.batch_iter, lambda: module(*input))\n",
    "    elif not conf.jit:\n",
    "        # Run lowering eager mode benchmark\n",
    "        lowered_module = compile(\n",
    "            module,\n",
    "            input,\n",
    "            max_batch_size=conf.batch_size,\n",
    "            lower_precision=LowerPrecision.FP16 if conf.fp16 else LowerPrecision.FP32,\n",
    "        )\n",
    "        time = benchmark_torch_function(conf.batch_iter, lambda: lowered_module(*input))\n",
    "    else:\n",
    "        print(\"Lowering with JIT is not available!\", \"red\")\n",
    "\n",
    "    result = Result(module=module, input=input, conf=conf, time_sec=time)\n",
    "    return result\n",
    "\n",
    "\n",
    "@torch.inference_mode()\n",
    "def benchmark(\n",
    "    model,\n",
    "    inputs,\n",
    "    batch_iter: int,\n",
    "    batch_size: int,\n",
    ") -> None:\n",
    "    model = model.cuda().eval()\n",
    "    inputs = [x.cuda() for x in inputs]\n",
    "\n",
    "    # benchmark base configuration\n",
    "    conf = Configuration(batch_iter=batch_iter, batch_size=batch_size)\n",
    "\n",
    "    configurations = [\n",
    "        # Baseline\n",
    "        replace(conf, name=\"CUDA Eager\", trt=False),\n",
    "        # FP32\n",
    "        replace(\n",
    "            conf,\n",
    "            name=\"TRT FP32 Eager\",\n",
    "            trt=True,\n",
    "            jit=False,\n",
    "            fp16=False,\n",
    "            accuracy_rtol=1e-3,\n",
    "        ),\n",
    "        # FP16\n",
    "        replace(\n",
    "            conf,\n",
    "            name=\"TRT FP16 Eager\",\n",
    "            trt=True,\n",
    "            jit=False,\n",
    "            fp16=True,\n",
    "            accuracy_rtol=1e-2,\n",
    "        ),\n",
    "    ]\n",
    "\n",
    "    results = [run_configuration_benchmark(deepcopy(model), inputs, conf_) for conf_ in configurations]\n",
    "\n",
    "    for res in results:\n",
    "        print(res.format())\n",
    "\n",
    "\n",
    "test_model = torchvision.models.resnet18(pretrained=True)\n",
    "input = [torch.rand(128, 3, 224, 224)]\n",
    "benchmark(test_model, input, 50, 128)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "customInput": null,
    "originalKey": "80bbae99-41ff-4baa-94a5-12bf0c9938f3",
    "showInput": true
   },
   "source": [
    ""
   ]
  }
 ],
 "metadata": {
  "bento_stylesheets": {
   "bento/extensions/flow/main.css": true,
   "bento/extensions/kernel_selector/main.css": true,
   "bento/extensions/kernel_ui/main.css": true,
   "bento/extensions/new_kernel/main.css": true,
   "bento/extensions/system_usage/main.css": true,
   "bento/extensions/theme/main.css": true
  },
  "dataExplorerConfig": {},
  "kernelspec": {
   "display_name": "dper3_pytorch (cinder)",
   "language": "python",
   "metadata": {
    "cinder_runtime": true,
    "fbpkg_supported": true,
    "is_prebuilt": true,
    "kernel_name": "bento_kernel_dper3_pytorch_cinder",
    "nightly_builds": false
   },
   "name": "bento_kernel_dper3_pytorch_cinder"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  },
  "last_base_url": "https://devgpu005.ftw6.facebook.com:8091/",
  "last_kernel_id": "5f014373-151c-4ee8-8939-4daab994d202",
  "last_msg_id": "687e81e8-4414f32c89cd026dd1ea3fd9_139",
  "last_server_session_id": "24a1a10c-29aa-4e2b-a11f-2b5108fc1e58",
  "outputWidgetContext": {}
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
